{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert v1 to v2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/convert_v1_to_v2.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install chemprop from GitHub if running in Google Colab\n",
    "import os\n",
    "\n",
    "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
    "    try:\n",
    "        import chemprop\n",
    "    except ImportError:\n",
    "        !git clone https://github.com/chemprop/chemprop.git\n",
    "        %cd chemprop\n",
    "        !pip install .\n",
    "        %cd examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from pprint import pprint\n",
    "from pathlib import Path\n",
    "\n",
    "from chemprop.utils.v1_to_v2 import convert_model_dict_v1_to_v2\n",
    "from chemprop.models.model import MPNN\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Change model paths here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parent\n",
    "model_v1_input_path =  chemprop_dir / \"tests/data/example_model_v1_regression_mol.pt\" # path to v1 model .pt file\n",
    "model_v2_output_path = Path.cwd() / \"converted_model.ckpt\" # path to save the converted model .ckpt file"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load v1 model .pt file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_v1_dict = torch.load(model_v1_input_path, weights_only=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['args',\n",
      " 'state_dict',\n",
      " 'data_scaler',\n",
      " 'features_scaler',\n",
      " 'atom_descriptor_scaler',\n",
      " 'bond_descriptor_scaler',\n",
      " 'atom_bond_scaler']\n"
     ]
    }
   ],
   "source": [
    "# Here are all the keys that is stored in v1 model\n",
    "pprint(list(model_v1_dict.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'activation': 'ReLU',\n",
      " 'adding_bond_types': True,\n",
      " 'adding_h': False,\n",
      " 'aggregation': 'mean',\n",
      " 'aggregation_norm': 100,\n",
      " 'atom_constraints': [],\n",
      " 'atom_descriptor_scaling': True,\n",
      " 'atom_descriptors': None,\n",
      " 'atom_descriptors_path': None,\n",
      " 'atom_descriptors_size': 0,\n",
      " 'atom_features_size': 0,\n",
      " 'atom_messages': False,\n",
      " 'atom_targets': [],\n",
      " 'batch_size': 50,\n",
      " 'bias': False,\n",
      " 'bias_solvent': False,\n",
      " 'bond_constraints': [],\n",
      " 'bond_descriptor_scaling': True,\n",
      " 'bond_descriptors': None,\n",
      " 'bond_descriptors_path': None,\n",
      " 'bond_descriptors_size': 0,\n",
      " 'bond_features_size': 0,\n",
      " 'bond_targets': [],\n",
      " 'cache_cutoff': 10000,\n",
      " 'checkpoint_dir': None,\n",
      " 'checkpoint_frzn': None,\n",
      " 'checkpoint_path': None,\n",
      " 'checkpoint_paths': None,\n",
      " 'class_balance': False,\n",
      " 'config_path': None,\n",
      " 'constraints_path': None,\n",
      " 'crossval_index_dir': None,\n",
      " 'crossval_index_file': None,\n",
      " 'crossval_index_sets': None,\n",
      " 'cuda': False,\n",
      " 'data_path': '/Users/hwpang/Software/chemprop/tests/data/regression.csv',\n",
      " 'data_weights_path': None,\n",
      " 'dataset_type': 'regression',\n",
      " 'depth': 3,\n",
      " 'depth_solvent': 3,\n",
      " 'device': device(type='cpu'),\n",
      " 'dropout': 0.0,\n",
      " 'empty_cache': False,\n",
      " 'ensemble_size': 1,\n",
      " 'epochs': 1,\n",
      " 'evidential_regularization': 0,\n",
      " 'explicit_h': False,\n",
      " 'extra_metrics': [],\n",
      " 'features_generator': None,\n",
      " 'features_only': False,\n",
      " 'features_path': None,\n",
      " 'features_scaling': True,\n",
      " 'features_size': None,\n",
      " 'ffn_hidden_size': 300,\n",
      " 'ffn_num_layers': 2,\n",
      " 'final_lr': 0.0001,\n",
      " 'folds_file': None,\n",
      " 'freeze_first_only': False,\n",
      " 'frzn_ffn_layers': 0,\n",
      " 'gpu': None,\n",
      " 'grad_clip': None,\n",
      " 'hidden_size': 300,\n",
      " 'hidden_size_solvent': 300,\n",
      " 'ignore_columns': None,\n",
      " 'init_lr': 0.0001,\n",
      " 'is_atom_bond_targets': False,\n",
      " 'keeping_atom_map': False,\n",
      " 'log_frequency': 10,\n",
      " 'loss_function': 'mse',\n",
      " 'max_data_size': None,\n",
      " 'max_lr': 0.001,\n",
      " 'metric': 'rmse',\n",
      " 'metrics': ['rmse'],\n",
      " 'minimize_score': True,\n",
      " 'mpn_shared': False,\n",
      " 'multiclass_num_classes': 3,\n",
      " 'no_adding_bond_types': False,\n",
      " 'no_atom_descriptor_scaling': False,\n",
      " 'no_bond_descriptor_scaling': False,\n",
      " 'no_cache_mol': False,\n",
      " 'no_cuda': False,\n",
      " 'no_features_scaling': False,\n",
      " 'no_shared_atom_bond_ffn': False,\n",
      " 'num_folds': 1,\n",
      " 'num_lrs': 1,\n",
      " 'num_tasks': 1,\n",
      " 'num_workers': 8,\n",
      " 'number_of_molecules': 1,\n",
      " 'overwrite_default_atom_features': False,\n",
      " 'overwrite_default_bond_features': False,\n",
      " 'phase_features_path': None,\n",
      " 'pytorch_seed': 0,\n",
      " 'quiet': False,\n",
      " 'reaction': False,\n",
      " 'reaction_mode': 'reac_diff',\n",
      " 'reaction_solvent': False,\n",
      " 'resume_experiment': False,\n",
      " 'save_dir': '/Users/hwpang/Software/test_chemprop_v1_to_v2/fold_0',\n",
      " 'save_preds': False,\n",
      " 'save_smiles_splits': True,\n",
      " 'seed': 0,\n",
      " 'separate_test_atom_descriptors_path': None,\n",
      " 'separate_test_bond_descriptors_path': None,\n",
      " 'separate_test_constraints_path': None,\n",
      " 'separate_test_features_path': None,\n",
      " 'separate_test_path': None,\n",
      " 'separate_test_phase_features_path': None,\n",
      " 'separate_val_atom_descriptors_path': None,\n",
      " 'separate_val_bond_descriptors_path': None,\n",
      " 'separate_val_constraints_path': None,\n",
      " 'separate_val_features_path': None,\n",
      " 'separate_val_path': None,\n",
      " 'separate_val_phase_features_path': None,\n",
      " 'shared_atom_bond_ffn': True,\n",
      " 'show_individual_scores': False,\n",
      " 'smiles_columns': ['smiles'],\n",
      " 'spectra_activation': 'exp',\n",
      " 'spectra_phase_mask': None,\n",
      " 'spectra_phase_mask_path': None,\n",
      " 'spectra_target_floor': 1e-08,\n",
      " 'split_key_molecule': 0,\n",
      " 'split_sizes': [0.8, 0.1, 0.1],\n",
      " 'split_type': 'random',\n",
      " 'target_columns': None,\n",
      " 'target_weights': None,\n",
      " 'task_names': ['logSolubility'],\n",
      " 'test': False,\n",
      " 'test_fold_index': None,\n",
      " 'train_data_size': 400,\n",
      " 'undirected': False,\n",
      " 'use_input_features': False,\n",
      " 'val_fold_index': None,\n",
      " 'warmup_epochs': 2.0,\n",
      " 'weights_ffn_num_layers': 2}\n"
     ]
    }
   ],
   "source": [
    "# Here are the input arguments that is stored in v1 model\n",
    "pprint(model_v1_dict['args'].__dict__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['encoder.encoder.0.cached_zero_vector',\n",
      " 'encoder.encoder.0.W_i.weight',\n",
      " 'encoder.encoder.0.W_h.weight',\n",
      " 'encoder.encoder.0.W_o.weight',\n",
      " 'encoder.encoder.0.W_o.bias',\n",
      " 'readout.1.weight',\n",
      " 'readout.1.bias',\n",
      " 'readout.4.weight',\n",
      " 'readout.4.bias']\n"
     ]
    }
   ],
   "source": [
    "# Here are the state_dict that is stored in v1 model\n",
    "pprint(list(model_v1_dict['state_dict'].keys()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert loaded v1 model dictionary into v2 model dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_v2_dict = convert_model_dict_v1_to_v2(model_v1_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['epoch',\n",
      " 'global_step',\n",
      " 'pytorch-lightning_version',\n",
      " 'state_dict',\n",
      " 'loops',\n",
      " 'callbacks',\n",
      " 'optimizer_states',\n",
      " 'lr_schedulers',\n",
      " 'hparams_name',\n",
      " 'hyper_parameters']\n"
     ]
    }
   ],
   "source": [
    "# Here are all the keys in the converted model\n",
    "pprint(list(model_v2_dict.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['message_passing.W_i.weight',\n",
      " 'message_passing.W_h.weight',\n",
      " 'message_passing.W_o.weight',\n",
      " 'message_passing.W_o.bias',\n",
      " 'predictor.ffn.0.0.weight',\n",
      " 'predictor.ffn.0.0.bias',\n",
      " 'predictor.ffn.1.2.weight',\n",
      " 'predictor.ffn.1.2.bias',\n",
      " 'predictor.output_transform.mean',\n",
      " 'predictor.output_transform.scale',\n",
      " 'predictor.criterion.task_weights']\n"
     ]
    }
   ],
   "source": [
    "# Here are all the keys in the converted state_dict\n",
    "pprint(list(model_v2_dict['state_dict'].keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['batch_norm',\n",
      " 'metrics',\n",
      " 'warmup_epochs',\n",
      " 'init_lr',\n",
      " 'max_lr',\n",
      " 'final_lr',\n",
      " 'message_passing',\n",
      " 'agg',\n",
      " 'predictor']\n"
     ]
    }
   ],
   "source": [
    "# Here are all the keys in the converted hyper_parameters\n",
    "pprint(list(model_v2_dict['hyper_parameters'].keys()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model_v2_dict, model_v2_output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load converted model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "mpnn = MPNN.load_from_checkpoint(model_v2_output_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPNN(\n",
       "  (message_passing): BondMessagePassing(\n",
       "    (W_i): Linear(in_features=147, out_features=300, bias=False)\n",
       "    (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "    (W_o): Linear(in_features=433, out_features=300, bias=True)\n",
       "    (dropout): Dropout(p=0.0, inplace=False)\n",
       "    (tau): ReLU()\n",
       "    (V_d_transform): Identity()\n",
       "    (graph_transform): Identity()\n",
       "  )\n",
       "  (agg): MeanAggregation()\n",
       "  (bn): Identity()\n",
       "  (predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=300, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0]])\n",
       "    (output_transform): UnscaleTransform()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metrics): ModuleList(\n",
       "    (0): RMSE(task_weights=[[1.0]])\n",
       "    (1): MSE(task_weights=[[1.0]])\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# now visually check the converted model is what is expected\n",
    "mpnn"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
